{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join as pjoin\n",
    "from tqdm import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.animation import FuncAnimation, PillowWriter\n",
    "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
    "import mpl_toolkits.mplot3d.axes3d as p3\n",
    "def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):\n",
    "#     matplotlib.use('Agg')\n",
    "\n",
    "    title_sp = title.split(' ')\n",
    "    if len(title_sp) > 10:\n",
    "        title = '\\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])\n",
    "    def init():\n",
    "        ax.set_xlim3d([-radius / 2, radius / 2])\n",
    "        ax.set_ylim3d([0, radius])\n",
    "        ax.set_zlim3d([0, radius])\n",
    "        # print(title)\n",
    "        fig.suptitle(title, fontsize=20)\n",
    "        ax.grid(b=False)\n",
    "\n",
    "    def plot_xzPlane(minx, maxx, miny, minz, maxz):\n",
    "        ## Plot a plane XZ\n",
    "        verts = [\n",
    "            [minx, miny, minz],\n",
    "            [minx, miny, maxz],\n",
    "            [maxx, miny, maxz],\n",
    "            [maxx, miny, minz]\n",
    "        ]\n",
    "        xz_plane = Poly3DCollection([verts])\n",
    "        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))\n",
    "        ax.add_collection3d(xz_plane)\n",
    "\n",
    "    #         return ax\n",
    "\n",
    "    # (seq_len, joints_num, 3)\n",
    "    data = joints.copy().reshape(len(joints), -1, 3)\n",
    "    fig = plt.figure(figsize=figsize)\n",
    "    ax = p3.Axes3D(fig)\n",
    "    init()\n",
    "    MINS = data.min(axis=0).min(axis=0)\n",
    "    MAXS = data.max(axis=0).max(axis=0)\n",
    "    colors = ['red', 'blue', 'black', 'red', 'blue',  \n",
    "              'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',\n",
    "             'darkred', 'darkred','darkred','darkred','darkred']\n",
    "    frame_number = data.shape[0]\n",
    "    #     print(data.shape)\n",
    "\n",
    "    height_offset = MINS[1]\n",
    "    data[:, :, 1] -= height_offset\n",
    "    trajec = data[:, 0, [0, 2]]\n",
    "    \n",
    "    data[..., 0] -= data[:, 0:1, 0]\n",
    "    data[..., 2] -= data[:, 0:1, 2]\n",
    "\n",
    "    #     print(trajec.shape)\n",
    "\n",
    "    def update(index):\n",
    "        #         print(index)\n",
    "        ax.lines = []\n",
    "        ax.collections = []\n",
    "        ax.view_init(elev=120, azim=-90)\n",
    "        ax.dist = 7.5\n",
    "        #         ax =\n",
    "        plot_xzPlane(MINS[0]-trajec[index, 0], MAXS[0]-trajec[index, 0], 0, MINS[2]-trajec[index, 1], MAXS[2]-trajec[index, 1])\n",
    "#         ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)\n",
    "        \n",
    "        if index > 1:\n",
    "            ax.plot3D(trajec[:index, 0]-trajec[index, 0], np.zeros_like(trajec[:index, 0]), trajec[:index, 1]-trajec[index, 1], linewidth=1.0,\n",
    "                      color='blue')\n",
    "        #             ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])\n",
    "        \n",
    "        \n",
    "        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):\n",
    "#             print(color)\n",
    "            if i < 5:\n",
    "                linewidth = 4.0\n",
    "            else:\n",
    "                linewidth = 2.0\n",
    "            ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, color=color)\n",
    "        #         print(trajec[:index, 0].shape)\n",
    "\n",
    "        plt.axis('off')\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_yticklabels([])\n",
    "        ax.set_zticklabels([])\n",
    "\n",
    "    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000/fps, repeat=False)\n",
    "\n",
    "    ani.save(save_path, fps=fps)\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_dir = './HumanML3D/new_joints/'\n",
    "tgt_ani_dir = \"./HumanML3D/animations/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]\n",
    "os.makedirs(tgt_ani_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "npy_files = os.listdir(src_dir)\n",
    "npy_files = sorted(npy_files)\n",
    "# npy_files = npy_files[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This will take a few hours for the whole dataset. Here we show ten animations for an example\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To accelerate the process, you could copy and run this script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:14<00:00,  1.43s/it]\n"
     ]
    }
   ],
   "source": [
    "for npy_file in tqdm(npy_files):\n",
    "    data = np.load(pjoin(src_dir, npy_file))\n",
    "    save_path = pjoin(tgt_ani_dir, npy_file[:-3] + 'mp4')\n",
    "    if os.path.exists(save_path):\n",
    "        continue\n",
    "#   You may set the title on your own.\n",
    "    plot_3d_motion(save_path, kinematic_chain, data, title=\"None\", fps=20, radius=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch_render]",
   "language": "python",
   "name": "conda-env-torch_render-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
